{
    "cells": [
        {
            "cell_type": "markdown",
            "id": "e45f9b60-cd6b-4c15-958f-1feca5438128",
            "metadata": {},
            "source": [
                "# SQL Index Demo\n",
                "\n",
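                "This demo shows how to query a SQL database by first retrieving the relevant table schema as context, for cases where the database has too many tables to describe in a single prompt."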
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 1,
            "id": "fbd7317b",
            "metadata": {},
            "outputs": [],
            "source": [
                "import logging\n",
                "import sys\n",
                "\n",
                "# basicConfig already attaches a stdout handler to the root logger, so no extra\n",
                "# StreamHandler is added here: a second one would print every record twice.\n",
                "logging.basicConfig(stream=sys.stdout, level=logging.INFO)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 2,
            "id": "107396a9-4aa7-49b3-9f0f-a755726c19ba",
            "metadata": {},
            "outputs": [],
            "source": [
                "from gpt_index import GPTSQLStructStoreIndex, SQLDatabase, SimpleDirectoryReader, WikipediaReader, Document\n",
                "from gpt_index.indices.struct_store import SQLContextContainerBuilder\n",
                "from IPython.display import Markdown, display"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "461438c8-302d-45c5-8e69-16ad604686d1",
            "metadata": {},
            "source": [
                "### Create Database Schema + Test Data\n",
                "\n",
                "Here we introduce a toy scenario with one real table (`city_stats`) plus 100 dummy tables, which is too many table schemas to fit into a single prompt."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 3,
            "id": "a370b266-66f5-4624-bbf9-2ad57f0511f8",
            "metadata": {},
            "outputs": [],
            "source": [
                "from sqlalchemy import create_engine, MetaData, Table, Column, String, Integer, select, column"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 30,
            "id": "ea24f794-f10b-42e6-922d-9258b7167405",
            "metadata": {},
            "outputs": [],
            "source": [
                "engine = create_engine(\"sqlite:///:memory:\")\n",
                "metadata_obj = MetaData(bind=engine)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 31,
            "id": "b4154b29-7e23-4c26-a507-370a66186ae7",
            "metadata": {},
            "outputs": [],
            "source": [
                "# create city SQL table\n",
                "table_name = \"city_stats\"\n",
                "city_stats_table = Table(\n",
                "    table_name,\n",
                "    metadata_obj,\n",
                "    Column(\"city_name\", String(16), primary_key=True),\n",
                "    Column(\"population\", Integer),\n",
                "    Column(\"country\", String(16), nullable=False),\n",
                ")\n",
                "all_table_names = [table_name]\n",
                "\n",
                "# create a ton of dummy tables so that the schemas no longer fit into one prompt\n",
                "n = 100\n",
                "for i in range(n):\n",
                "    tmp_table_name = f\"tmp_table_{i}\"\n",
                "    # constructing a Table registers it on metadata_obj; the object itself is not needed\n",
                "    Table(\n",
                "        tmp_table_name,\n",
                "        metadata_obj,\n",
                "        Column(f\"tmp_field_{i}_1\", String(16), primary_key=True),\n",
                "        Column(f\"tmp_field_{i}_2\", Integer),\n",
                "        Column(f\"tmp_field_{i}_3\", String(16), nullable=False),\n",
                "    )\n",
                "    all_table_names.append(tmp_table_name)\n",
                "\n",
                "# pass the engine explicitly instead of relying on the MetaData object being bound\n",
                "metadata_obj.create_all(engine)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "4768bcb4-c40e-4d5d-8d70-7cb3228b50ab",
            "metadata": {},
            "outputs": [],
            "source": [
                "# print tables\n",
                "metadata_obj.tables.keys()"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "4c0eb518-5da3-4215-8280-0776d07806a0",
            "metadata": {},
            "source": [
                "We introduce some test data into the `city_stats` table"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 33,
            "id": "d15192b6-99f9-4f72-b637-82e885ea057f",
            "metadata": {},
            "outputs": [],
            "source": [
                "from sqlalchemy import insert\n",
                "\n",
                "rows = [\n",
                "    {\"city_name\": \"Toronto\", \"population\": 2930000, \"country\": \"Canada\"},\n",
                "    {\"city_name\": \"Tokyo\", \"population\": 13960000, \"country\": \"Japan\"},\n",
                "    {\"city_name\": \"Chicago\", \"population\": 2679000, \"country\": \"United States\"},\n",
                "    {\"city_name\": \"Seoul\", \"population\": 9776000, \"country\": \"South Korea\"},\n",
                "]\n",
                "# engine.begin() runs all inserts on one connection and commits them when the block\n",
                "# exits, instead of opening a new connection per row and relying on autocommit\n",
                "with engine.begin() as connection:\n",
                "    connection.execute(insert(city_stats_table), rows)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 34,
            "id": "bfc2e4a4-e11d-4d8f-bf1f-7f777a1dc6e2",
            "metadata": {},
            "outputs": [
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "[('Toronto', 2930000, 'Canada'), ('Tokyo', 13960000, 'Japan'), ('Chicago', 2679000, 'United States'), ('Seoul', 9776000, 'South Korea')]\n"
                    ]
                }
            ],
            "source": [
                "# read the table back with a raw SQL string and print every row\n",
                "with engine.connect() as connection:\n",
                "    all_rows = connection.exec_driver_sql(\"SELECT * FROM city_stats\").fetchall()\n",
                "print(all_rows)"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "1c09089a-6bcd-48db-8120-a84c8da3f82e",
            "metadata": {
                "tags": []
            },
            "source": [
                "### Using GPT Index to Store Table Schema Context"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 35,
            "id": "611319e5-d3c2-4286-a84f-ed2459896c58",
            "metadata": {},
            "outputs": [],
            "source": [
                "from gpt_index import GPTSQLStructStoreIndex, SQLDatabase, GPTSimpleVectorIndex\n",
                "from gpt_index.indices.struct_store import SQLContextContainerBuilder"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 36,
            "id": "3fc2dfab-90ea-4f01-9e28-d21fdc5f0758",
            "metadata": {},
            "outputs": [],
            "source": [
                "sql_database = SQLDatabase(engine)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "89f6f1d1-a022-43d7-b135-a79ec9407956",
            "metadata": {},
            "outputs": [],
            "source": [
                "sql_database.table_info"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "331ff0ce-9131-4680-a5f2-3f41c73e018e",
            "metadata": {},
            "source": [
                "We dump the table schema information into a vector index. The vector index is stored within the context builder for future use."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "1e86d548-a3f4-436a-a754-5247871ebe55",
            "metadata": {
                "scrolled": true,
                "tags": []
            },
            "outputs": [],
            "source": [
                "# Wrap the database in a context builder, then derive a simple vector index\n",
                "# from the table schema information it holds.\n",
                "context_builder = SQLContextContainerBuilder(sql_database)\n",
                "table_schema_index = context_builder.derive_index_from_context(GPTSimpleVectorIndex)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "a743f365-21c6-4eae-a2f4-fc72d4199daa",
            "metadata": {},
            "outputs": [],
            "source": [
                "# NOTE: no unstructured documents are ingested at the moment, hence the empty list\n",
                "no_documents = []\n",
                "index = GPTSQLStructStoreIndex.from_documents(\n",
                "    no_documents, sql_database=sql_database, table_name=\"city_stats\"\n",
                ")"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "051a171f-8c97-40ed-ae17-4e3fa3785487",
            "metadata": {},
            "source": [
                "### Query Index"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "91139712-f232-47e1-9683-cbbd49cd331b",
            "metadata": {},
            "source": [
                "Here we show a natural language query.\n",
                "\n",
                "1. We first query the vector index for the right table schema. Note that the context container is built at query time.\n",
                "2. Given this context container, we execute the NL query against the database."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 44,
            "id": "eabededd-3c17-45b7-aabc-06a2457bc3cb",
            "metadata": {},
            "outputs": [
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "INFO:root:> [query] Total LLM token usage: 135 tokens\n",
                        "> [query] Total LLM token usage: 135 tokens\n",
                        "INFO:root:> [query] Total embedding token usage: 23 tokens\n",
                        "> [query] Total embedding token usage: 23 tokens\n",
                        "\n",
                        "Table 'city_stats':\n",
                        "city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16))\n"
                    ]
                }
            ],
            "source": [
                "query_str = \"Which city has the highest population?\"\n",
                "\n",
                "# look up the table schema that matches the question in the schema index\n",
                "context_builder.query_index_for_context(\n",
                "    table_schema_index,\n",
                "    query_str,\n",
                "    store_context_str=True,\n",
                ")\n",
                "# build the container that is later passed along with the query\n",
                "context_container = context_builder.build_context_container()"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 45,
            "id": "a80ee856-6ac3-4b37-b390-be583024bed4",
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/markdown": [
                            "<b>\n",
                            "Table 'city_stats':\n",
                            "city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16))</b>"
                        ],
                        "text/plain": [
                            "<IPython.core.display.Markdown object>"
                        ]
                    },
                    "metadata": {},
                    "output_type": "display_data"
                }
            ],
            "source": [
                "display(Markdown(f\"<b>{context_container.context_str}</b>\"))"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 40,
            "id": "64671ddc-9768-40c2-8898-ab7c0cf10917",
            "metadata": {},
            "outputs": [
                {
                    "name": "stdout",
                    "output_type": "stream",
                    "text": [
                        "INFO:root:> Table desc str: \n",
                        "Table 'city_stats':\n",
                        "city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16))\n",
                        "> Table desc str: \n",
                        "Table 'city_stats':\n",
                        "city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16))\n",
                        "INFO:root:> [query] Total LLM token usage: 134 tokens\n",
                        "> [query] Total LLM token usage: 134 tokens\n",
                        "INFO:root:> [query] Total embedding token usage: 0 tokens\n",
                        "> [query] Total embedding token usage: 0 tokens\n"
                    ]
                }
            ],
            "source": [
                "response = index.query(query_str, sql_context_container=context_container)"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "5dc2f7bf-6f6c-42ba-8f42-47afea6606ad",
            "metadata": {},
            "source": [
                "We can also use codewords during the NL query! "
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 41,
            "id": "25c11645-56bd-433a-85f4-420413f8970d",
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/plain": [
                            "\"[('Tokyo',)]\""
                        ]
                    },
                    "execution_count": 41,
                    "metadata": {},
                    "output_type": "execute_result"
                }
            ],
            "source": [
                "str(response)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "3f72abc6-54d7-4f85-abf8-32978d94f558",
            "metadata": {},
            "outputs": [],
            "source": [
                "response.extra_info"
            ]
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "gpt_retrieve_venv",
            "language": "python",
            "name": "gpt_retrieve_venv"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.9.16"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 5
}